{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow2 推理\n",
    "\n",
    "参考：[migrating_checkpoints](https://www.tensorflow.org/guide/migrate/migrating_checkpoints)\n",
    "\n",
    "下面以模型 [resnet_v2_50](http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz) 为例展示。\n",
    "\n",
    "需要克隆项目 [models](https://github.com/tensorflow/models)，然后执行如下操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "try:\n",
    "    tf1 = tf.compat.v1\n",
    "except (ImportError, AttributeError):\n",
    "    tf1 = tf\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "切换到 `models/research/slim` 目录下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd /media/pc/data/lxw/ai/tasks/models/research/slim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash \n",
    "pip install --upgrade tf_slim tf-keras -i https://pypi.tuna.tsinghua.edu.cn/simple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将 TF1 升级为 TF2："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nets import resnet_v2\n",
    "import tf_slim as slim\n",
    "\n",
    "\n",
    "class ResnetV2_50(tf.Module):\n",
    "    @tf.function(input_signature=[tf.TensorSpec([None, 299, 299, 3], tf.float32, name=\"data\")])\n",
    "    @tf1.keras.utils.track_tf1_style_variables\n",
    "    def __call__(self, data):\n",
    "        with slim.arg_scope(resnet_v2.resnet_arg_scope()):\n",
    "            logits, end_points = resnet_v2.resnet_v2_50(\n",
    "                data, \n",
    "                num_classes=1001,\n",
    "                global_pool=True,\n",
    "                is_training=False,\n",
    "                scope=\"resnet_v2_50\"\n",
    "            )\n",
    "        del end_points\n",
    "        return tf.nn.softmax(logits) "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预处理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "from nets import resnet_v2\n",
    "from tvm_book.data.classification import ImageFolderDataset\n",
    "import tf_slim as slim\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def preprocessing(\n",
    "    image,\n",
    "    use_grayscale=False,\n",
    "    central_fraction=0.875,\n",
    "    central_crop=True,\n",
    "    height=299,\n",
    "    width=299,\n",
    "    mean: tuple[float, ...] = (0.485, 0.456, 0.406),\n",
    "    std: tuple[float, ...] = (1, 1, 1)\n",
    "):\n",
    "    # image = tf.constant(image)\n",
    "    if image.dtype != tf.float32:\n",
    "        image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n",
    "    if use_grayscale:\n",
    "        image = tf.image.rgb_to_grayscale(image)\n",
    "    if central_crop and central_fraction:\n",
    "        image = tf.image.central_crop(image, central_fraction=central_fraction)\n",
    "    if height and width:\n",
    "        image = tf.expand_dims(image, 0)\n",
    "        image = tf.image.resize(image, [height, width],\n",
    "                                method='bilinear',\n",
    "                                preserve_aspect_ratio=False,\n",
    "                                antialias=False)\n",
    "        image = tf.squeeze(image, [0])\n",
    "    image = tf.subtract(image, mean)\n",
    "    image = tf.divide(image, std)\n",
    "    return image\n",
    "\n",
    "\n",
    "# 预处理\n",
    "root = \"/media/pc/data/lxw/home/data/datasets/ILSVRC/val\"\n",
    "valset = ImageFolderDataset(root)\n",
    "image, label_id = valset[1001]\n",
    "model_dir = 'temp/resnet_v2_50'\n",
    "# remove_dir(model_dir)\n",
    "processed_image = preprocessing(\n",
    "    image,\n",
    "    use_grayscale=False,\n",
    "    central_fraction=0.875,\n",
    "    central_crop=True,\n",
    "    height=299,\n",
    "    width=299,\n",
    "    mean=(0.485, 0.456, 0.406),\n",
    "    std=(1, 1, 1)\n",
    ")\n",
    "np_processed_images = np.expand_dims(processed_image.numpy(), axis=0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前向推理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ResnetV2_50()\n",
    "model(tf.ones(shape=(1, 299, 299, 3), dtype=tf.float32))\n",
    "ckpt = tf.train.Checkpoint(model=model)\n",
    "ckpt.restore(\".temp/checkpoints/resnet_v2_50.ckpt\") # 更新模型参数\n",
    "outputs = model(np_processed_images)\n",
    "outputs = outputs.numpy()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印标签信息："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm_book.data.imagenet.classification import ImageNet1kAttr\n",
    "\n",
    "imagenet1k_attr = ImageNet1kAttr()\n",
    "sorted_inds = outputs[0].argsort()[::-1]\n",
    "topk = 5\n",
    "print(f\"真实标签：{imagenet1k_attr.classes_long[label_id]}\")\n",
    "for sorted_ind in sorted_inds[:topk]:\n",
    "    label = imagenet1k_attr.classes_long[sorted_ind-1]\n",
    "    print(f\"{sorted_ind-1}: {label.ljust(38)}\\t{outputs[0, sorted_ind]}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将其模型和参数保存下来："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_with_signature_path = 'temp/module_with_signature'\n",
    "call = model.__call__.get_concrete_function(tf.TensorSpec([1, 299, 299, 3], tf.float32, name=\"data\"))\n",
    "tf.saved_model.save(model, module_with_signature_path, signatures=call)\n",
    "imported_with_signatures = tf.saved_model.load(module_with_signature_path)\n",
    "infer = imported_with_signatures.signatures['serving_default']\n",
    "labeling = infer(tf.constant(np_processed_images))['output_0']\n",
    "# gdef = model.__call__.get_concrete_function().graph.as_graph_def(add_shapes=True)\n",
    "# gdef_ops = list(set([n.op for n in gdef.node]))\n",
    "# gdef = infer.graph.as_graph_def(add_shapes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
